package com.gitee.jmash.rbac.client.shiro.authz;

import com.gitee.jmash.rbac.client.shiro.JmashJsonWebToken;
import io.smallrye.jwt.algorithm.KeyEncryptionAlgorithm;
import io.smallrye.jwt.algorithm.SignatureAlgorithm;
import io.smallrye.jwt.auth.principal.DefaultJWTCallerPrincipal;
import io.smallrye.jwt.auth.principal.DefaultJWTTokenParser;
import io.smallrye.jwt.auth.principal.JWTAuthContextInfo;
import io.smallrye.jwt.auth.principal.JWTParser;
import io.smallrye.jwt.auth.principal.ParseException;
import io.smallrye.jwt.util.KeyUtils;
import java.security.Key;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.util.Collections;
import java.util.Set;
import javax.crypto.SecretKey;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Typed;
import jakarta.inject.Inject;
import org.eclipse.microprofile.jwt.Claims;
import org.eclipse.microprofile.jwt.JsonWebToken;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.jose4j.jwt.consumer.JwtContext;

/**
 * Jmash JWT parser, used in place of smallrye-jwt's {@code DefaultJWTParser}.
 *
 * <p>The parse/verify/decrypt methods return a {@link JmashJsonWebToken} built from the token's
 * claims and the {@code typ} header of the first JOSE object reported by the parse context.
 *
 * @author CGD
 *
 */
@ApplicationScoped
@Typed(JmashJWTParser.class)
public class JmashJWTParser implements JWTParser {

  private static final String ED_EC_PUBLIC_KEY_INTERFACE = "java.security.interfaces.EdECPublicKey";
  private static final String XEC_PRIVATE_KEY_INTERFACE = "java.security.interfaces.XECPrivateKey";

  private final DefaultJWTTokenParser parser = new DefaultJWTTokenParser();

  @Inject
  private JWTAuthContextInfo authContextInfo;

  /**
   * Parses and validates a token using the injected {@link JWTAuthContextInfo}.
   *
   * @param token the serialized JWT
   * @return the parsed token
   * @throws ParseException if the token cannot be parsed or validated
   */
  @Override
  public JmashJsonWebToken parse(final String token) throws ParseException {
    return parse(token, authContextInfo);
  }

  /**
   * Parses and validates a token using the given context.
   *
   * @param bearerToken the serialized JWT
   * @param context the verification/decryption configuration to apply; when {@code null} the
   *        injected configuration is used instead
   * @return the parsed token
   * @throws ParseException if the token cannot be parsed or validated
   */
  @Override
  public JmashJsonWebToken parse(String bearerToken, JWTAuthContextInfo context)
      throws ParseException {
    // Honor the caller-supplied context: verify()/decrypt() pass a copy that carries the explicit
    // key and algorithm. Previously this context was ignored, so those keys never took effect.
    JWTAuthContextInfo effectiveContext = context != null ? context : authContextInfo;
    JwtContext jwtContext = parser.parse(bearerToken, effectiveContext);
    String type = jwtContext.getJoseObjects().get(0).getHeader("typ");
    return new JmashJsonWebToken(type, jwtContext.getJwtClaims());
  }

  /**
   * Verifies a signed token with the given public key.
   *
   * <p>If no signature algorithm of the key's family is configured, a default is chosen: ES256
   * for EC keys, EdDSA for EdEC keys, otherwise RS256.
   *
   * @param bearerToken the serialized JWT
   * @param key the public verification key
   * @return the parsed token
   * @throws ParseException if the token cannot be parsed or verified
   */
  @Override
  public JmashJsonWebToken verify(String bearerToken, PublicKey key) throws ParseException {
    JWTAuthContextInfo newAuthContextInfo = copyAuthContextInfo();
    newAuthContextInfo.setPublicVerificationKey(key);
    if (key instanceof ECPublicKey) {
      setSignatureAlgorithmIfNeeded(newAuthContextInfo, "ES", SignatureAlgorithm.ES256);
    } else if (isEdECPublicKey(key)) {
      setSignatureAlgorithmIfNeeded(newAuthContextInfo, "EdDSA", SignatureAlgorithm.EDDSA);
    } else {
      setSignatureAlgorithmIfNeeded(newAuthContextInfo, "RS", SignatureAlgorithm.RS256);
    }
    return parse(bearerToken, newAuthContextInfo);
  }

  /**
   * Verifies an HMAC-signed token with the given secret key (defaults to HS256 if no HS*
   * algorithm is configured).
   *
   * @param bearerToken the serialized JWT
   * @param key the secret verification key
   * @return the parsed token
   * @throws ParseException if the token cannot be parsed or verified
   */
  @Override
  public JmashJsonWebToken verify(String bearerToken, SecretKey key) throws ParseException {
    JWTAuthContextInfo newAuthContextInfo = copyAuthContextInfo();
    newAuthContextInfo.setSecretVerificationKey(key);
    setSignatureAlgorithmIfNeeded(newAuthContextInfo, "HS", SignatureAlgorithm.HS256);
    return parse(bearerToken, newAuthContextInfo);
  }

  /**
   * Verifies an HMAC-signed token with a secret given as a string.
   *
   * @param bearerToken the serialized JWT
   * @param secret the shared secret, converted via {@link KeyUtils#createSecretKeyFromSecret}
   * @return the parsed token
   * @throws ParseException if the token cannot be parsed or verified
   */
  @Override
  public JmashJsonWebToken verify(String bearerToken, String secret) throws ParseException {
    return verify(bearerToken, KeyUtils.createSecretKeyFromSecret(secret));
  }

  /**
   * Decrypts a token with the given private key.
   *
   * <p>If no key-encryption algorithm of the key's family is configured, a default is chosen:
   * ECDH-ES+A256KW for EC/XEC keys, otherwise RSA-OAEP.
   *
   * @param bearerToken the serialized JWE
   * @param key the private decryption key
   * @return the parsed token
   * @throws ParseException if the token cannot be decrypted or parsed
   */
  @Override
  public JmashJsonWebToken decrypt(String bearerToken, PrivateKey key) throws ParseException {
    JWTAuthContextInfo newAuthContextInfo = copyAuthContextInfo();
    newAuthContextInfo.setPrivateDecryptionKey(key);
    if (key instanceof ECPrivateKey || isXecPrivateKey(key)) {
      setKeyEncryptionAlgorithmIfNeeded(newAuthContextInfo, "EC",
          KeyEncryptionAlgorithm.ECDH_ES_A256KW);
    } else {
      setKeyEncryptionAlgorithmIfNeeded(newAuthContextInfo, "RS", KeyEncryptionAlgorithm.RSA_OAEP);
    }
    return parse(bearerToken, newAuthContextInfo);
  }

  /**
   * Decrypts a token with the given secret key (defaults to A256KW if not already configured).
   *
   * @param bearerToken the serialized JWE
   * @param key the secret decryption key
   * @return the parsed token
   * @throws ParseException if the token cannot be decrypted or parsed
   */
  @Override
  public JmashJsonWebToken decrypt(String bearerToken, SecretKey key) throws ParseException {
    JWTAuthContextInfo newAuthContextInfo = copyAuthContextInfo();
    newAuthContextInfo.setSecretDecryptionKey(key);
    setKeyEncryptionAlgorithmIfNeeded(newAuthContextInfo, "A256KW", KeyEncryptionAlgorithm.A256KW);
    return parse(bearerToken, newAuthContextInfo);
  }

  /**
   * Decrypts a token with a secret given as a string.
   *
   * @param bearerToken the serialized JWE
   * @param secret the shared secret, converted via {@link KeyUtils#createSecretKeyFromSecret}
   * @return the parsed token
   * @throws ParseException if the token cannot be decrypted or parsed
   */
  @Override
  public JmashJsonWebToken decrypt(String bearerToken, String secret) throws ParseException {
    return decrypt(bearerToken, KeyUtils.createSecretKeyFromSecret(secret));
  }

  /** Returns a mutable copy of the injected context, or a fresh default one if none was injected. */
  private JWTAuthContextInfo copyAuthContextInfo() {
    return authContextInfo != null ? new JWTAuthContextInfo(authContextInfo)
        : new JWTAuthContextInfo();
  }

  /**
   * Replaces the signature algorithms with {@code newAlgo} unless one whose name starts with
   * {@code algoStart} is already configured.
   */
  private void setSignatureAlgorithmIfNeeded(JWTAuthContextInfo newAuthContextInfo, String algoStart,
      SignatureAlgorithm newAlgo) {
    Set<SignatureAlgorithm> algo = newAuthContextInfo.getSignatureAlgorithm();
    // Null-guarded like setKeyEncryptionAlgorithmIfNeeded, so an unset value cannot cause an NPE.
    if (algo == null || algo.stream().noneMatch(s -> s.getAlgorithm().startsWith(algoStart))) {
      newAuthContextInfo.setSignatureAlgorithm(Collections.singleton(newAlgo));
    }
  }

  /**
   * Replaces the key-encryption algorithms with {@code newAlgo} unless one whose name starts with
   * {@code algoStart} is already configured.
   */
  private void setKeyEncryptionAlgorithmIfNeeded(JWTAuthContextInfo newAuthContextInfo,
      String algoStart, KeyEncryptionAlgorithm newAlgo) {
    Set<KeyEncryptionAlgorithm> algo = newAuthContextInfo.getKeyEncryptionAlgorithm();
    if (algo == null || algo.stream().noneMatch(s -> s.getAlgorithm().startsWith(algoStart))) {
      newAuthContextInfo.setKeyEncryptionAlgorithm(Collections.singleton(newAlgo));
    }
  }

  /** Checks by interface name so the class still loads on JDKs without EdEC support. */
  private static boolean isEdECPublicKey(Key verificationKey) {
    return KeyUtils.isSupportedKey(verificationKey, ED_EC_PUBLIC_KEY_INTERFACE);
  }

  /** Checks by interface name so the class still loads on JDKs without XEC support. */
  private static boolean isXecPrivateKey(Key encKey) {
    return KeyUtils.isSupportedKey(encKey, XEC_PRIVATE_KEY_INTERFACE);
  }

  /**
   * Parses a token's claims without verifying its signature or running any claim validators.
   * The original token is stored in the {@code raw_token} claim.
   *
   * @param token the serialized JWT
   * @return the unverified token
   * @throws ParseException if the token is malformed and cannot be parsed at all
   */
  @Override
  public JsonWebToken parseOnly(String token) throws ParseException {
    try {
      JwtClaims claims = new JwtConsumerBuilder().setSkipSignatureVerification()
          .setSkipAllValidators().build().processToClaims(token);
      claims.setClaim(Claims.raw_token.name(), token);
      return new DefaultJWTCallerPrincipal(claims);
    } catch (InvalidJwtException e) {
      // Report failure through the declared exception rather than returning null; the token
      // itself is deliberately left out of the message.
      throw new ParseException("Failed to parse JWT", e);
    }
  }


}
